{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d093f55-b308-47d5-b255-a47f1df1b2df",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Add the project's files to the python path\n",
    "# file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # for .py script\n",
    "file_path = os.path.dirname(os.path.abspath(''))  # for .ipynb notebook\n",
    "sys.path.append(file_path)\n",
    "\n",
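    "# Necessary for advanced config parsing with hydra and omegaconf.\n",
    "# `replace=True` allows re-running this cell without an error about the\n",
    "# resolver being already registered\n",
    "from omegaconf import OmegaConf\n",
    "OmegaConf.register_new_resolver(\"eval\", eval, replace=True)\n",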
    "\n",
    "import hydra\n",
    "from src.utils import init_config\n",
    "import torch\n",
    "from src.visualization import show\n",
    "from src.datasets.dales import CLASS_NAMES, CLASS_COLORS\n",
    "from src.datasets.dales import DALES_NUM_CLASSES as NUM_CLASSES\n",
    "from src.transforms import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3496cc8f-7fa4-481b-b4f6-5dc48993494a",
   "metadata": {},
   "source": [
    "## Parsing the config files\n",
    "Hydra and OmegaConf are used to parse the `yaml` config files.\n",
    "\n",
    "❗Make sure to **set the path to a relevant ckpt file**. \n",
    "You can use our pretrained models for this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d526f72f-4585-4ff9-9ce1-a7c1885e572f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the configs using hydra\n",
    "cfg = init_config(overrides=[\n",
    "    \"experiment=dales\",\n",
    "    \"ckpt_path=path/to/your/checkpoint.ckpt\"\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e4e98af-a7eb-4b40-b8e3-8e64d70498f5",
   "metadata": {},
   "source": [
    "## Datamodule and model instantiation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d06f3d5d-d42d-449c-b494-f5f5c71be4e4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Instantiate the datamodule\n",
    "datamodule = hydra.utils.instantiate(cfg.datamodule)\n",
    "datamodule.prepare_data()\n",
    "datamodule.setup()\n",
    "\n",
    "# Instantiate the model\n",
    "model = hydra.utils.instantiate(cfg.model)\n",
    "\n",
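    "# Load pretrained weights from a checkpoint file. `load_from_checkpoint`\n",
    "# is a classmethod, so we call it on the class rather than on the instance\n",
    "model = type(model).load_from_checkpoint(cfg.ckpt_path, net=model.net, criterion=model.criterion)\n",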
    "model = model.eval().cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f95ac9f4-3266-47b9-909d-b189dabeec2d",
   "metadata": {},
   "source": [
    "## Hierarchical partition loading and inference\n",
    "SPT can process very large scenes at once.\n",
    "The dataset stage you pick determines what inference runs on: the train dataset yields a spherical sample of a tile, while the val and test datasets yield a whole million-point tile."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4de618d2-d809-4c6a-b38c-7a97f66b3615",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Pick among train, val, and test datasets. It is important to note that\n",
    "# the train dataset produces augmented spherical samples of large \n",
    "# scenes, while the val and test datasets produce entire tiles\n",
    "# dataset = datamodule.train_dataset\n",
    "dataset = datamodule.val_dataset\n",
    "# dataset = datamodule.test_dataset\n",
    "\n",
    "# For the sake of visualization, we require that NAGAddKeysTo does not \n",
    "# remove input Data attributes after moving them to Data.x, so we may \n",
    "# visualize them\n",
    "for t in dataset.on_device_transform.transforms:\n",
    "    if isinstance(t, NAGAddKeysTo):\n",
    "        t.delete_after = False\n",
    "\n",
    "# Load a dataset item. This will return the hierarchical partition of an \n",
    "# entire tile, within a NAG object \n",
    "nag = dataset[0]\n",
    "\n",
    "# Apply on-device transforms on the NAG object. For the train dataset, \n",
    "# this will select a spherical sample of the larger tile and apply some\n",
    "# data augmentations. For the validation and test datasets, this will\n",
    "# prepare an entire tile for inference\n",
    "nag = dataset.on_device_transform(nag.cuda())\n",
    "\n",
    "# Inference, without tracking gradients to save memory\n",
    "with torch.no_grad():\n",
    "    logits = model(nag)\n",
    "\n",
    "# If the model outputs multi-stage predictions, we take the first one, \n",
    "# corresponding to level-1 predictions \n",
    "if model.multi_stage_loss:\n",
    "    logits = logits[0]\n",
    "\n",
    "# Compute the level-0 (pointwise) predictions based on the predictions\n",
    "# on level-1 superpoints\n",
    "l1_preds = torch.argmax(logits, dim=1).detach()\n",
    "l0_preds = l1_preds[nag[0].super_index]\n",
    "\n",
    "# Save predictions for visualization in the level-0 Data attributes \n",
    "nag[0].pred = l0_preds"
   ]
  },
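  {
   "cell_type": "markdown",
   "id": "a1f0c3d2-6b7e-4c58-9d21-3e5f7a9b0c14",
   "metadata": {},
   "source": [
    "The last indexing step above, `l1_preds[nag[0].super_index]`, copies each superpoint's predicted label to the points it contains. Here is a minimal standalone sketch, assuming toy tensors in place of the real model output: the `logits` values are made up and `super_index` is a hand-written stand-in for `nag[0].super_index`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Toy logits for 3 level-1 superpoints and 4 classes (made-up values)\n",
    "logits = torch.tensor([\n",
    "    [0.1, 2.0, 0.3, 0.0],  # superpoint 0 -> class 1\n",
    "    [1.5, 0.2, 0.1, 0.4],  # superpoint 1 -> class 0\n",
    "    [0.0, 0.1, 0.2, 3.0],  # superpoint 2 -> class 3\n",
    "])\n",
    "\n",
    "# Hypothetical stand-in for nag[0].super_index: for each of the 6 points,\n",
    "# the index of the level-1 superpoint it belongs to\n",
    "super_index = torch.tensor([0, 0, 2, 1, 2, 1])\n",
    "\n",
    "# One predicted label per superpoint\n",
    "l1_preds = torch.argmax(logits, dim=1)\n",
    "\n",
    "# Indexing with super_index copies each superpoint label to its points\n",
    "l0_preds = l1_preds[super_index]\n",
    "print(l0_preds.tolist())  # [1, 1, 3, 0, 3, 0]\n",
    "```\n",
    "\n",
    "Since `super_index` holds one parent index per point, the result contains one label per point."
   ]
  },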
  {
   "cell_type": "markdown",
   "id": "9c1ab511-dc3a-4750-bde8-5c6fcae7a516",
   "metadata": {},
   "source": [
    "## Visualizing an entire tile\n",
    "SPT can process very large scenes at once. Let's visualize the output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ae5bec5-01d2-41e3-90ec-244a61cac5d3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Visualize the hierarchical partition\n",
    "show(\n",
    "    nag, \n",
    "    class_names=CLASS_NAMES, \n",
    "    ignore=NUM_CLASSES,\n",
    "    class_colors=CLASS_COLORS,\n",
    "    max_points=100000\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c986a13-ee00-4538-aa7d-26c5d4077dc5",
   "metadata": {},
   "source": [
    "However, for memory reasons, the visualization cannot display all points. Let's have a look at a smaller area."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15a653dd-291c-43c6-8c83-78ebff2d00b9",
   "metadata": {},
   "source": [
    "## Selecting a portion of the hierarchical partition\n",
    "The NAG structure can be subselected using `nag.select()`.\n",
    "\n",
    "This function expects an `int` specifying the partition level from which we should select, along with an index or a mask in the form of a `list`, `numpy.ndarray`, `torch.Tensor`, or `slice`.\n",
    "This index/mask describes which nodes to select at the specified level.\n",
    "\n",
    "The output NAG will only contain children, parents and edges of the selected nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d86de234-3da8-4d10-a373-7dce47c514da",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Pick a center and radius for the spherical sample\n",
    "center = torch.tensor([[40, 115, 0]]).to(nag.device)\n",
    "radius = 10\n",
    "\n",
    "# Create a level-0 (i.e. point-level) index to be used for indexing the NAG \n",
    "# structure. torch.where returns the indices of the points inside the sphere\n",
    "mask = torch.where(torch.linalg.norm(nag[0].pos - center, dim=1) < radius)[0]\n",
    "\n",
    "# Subselect the hierarchical partition based on the level-0 mask\n",
    "nag_visu = nag.select(0, mask)"
   ]
  },
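  {
   "cell_type": "markdown",
   "id": "b7e2d914-0c3a-4f6b-8a55-1d9e6c2f4a73",
   "metadata": {},
   "source": [
    "The spherical selection above is only one way to build a level-0 index. As an illustration, the sketch below builds an index for an axis-aligned box instead, using plain `torch` on random points. The `pos` tensor is a hypothetical stand-in for `nag[0].pos`, and the box bounds are arbitrary values chosen for the example.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Hypothetical stand-in for nag[0].pos: 1000 random points in a cube of side 100\n",
    "pos = torch.rand(1000, 3) * 100\n",
    "\n",
    "# Arbitrary box bounds, chosen for illustration only\n",
    "box_min = torch.tensor([30.0, 40.0, 0.0])\n",
    "box_max = torch.tensor([50.0, 60.0, 20.0])\n",
    "\n",
    "# A point is kept if all of its coordinates fall inside the box\n",
    "inside = ((pos >= box_min) & (pos <= box_max)).all(dim=1)\n",
    "\n",
    "# Convert the boolean mask to an index tensor, as torch.where does above\n",
    "idx = torch.where(inside)[0]\n",
    "print(idx.shape[0], 'points selected out of', pos.shape[0])\n",
    "```\n",
    "\n",
    "The resulting `idx` could then be passed to `nag.select(0, idx)` in the same way as the spherical index."
   ]
  },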
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ef97ce1-527d-418d-b273-839c721ede6e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Visualize the sample\n",
    "show(\n",
    "    nag_visu,\n",
    "    class_names=CLASS_NAMES,\n",
    "    ignore=NUM_CLASSES,\n",
    "    class_colors=CLASS_COLORS, \n",
    "    max_points=100000\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abf42a3e-2ba8-46f7-9eda-a9a22bc05674",
   "metadata": {},
   "source": [
    "## Visualizing the superpoint graphs\n",
    "Let's take a closer look at the graph connecting superpoints by setting `centroids=True` and `h_edge=True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b9f9ebf-6549-4c7b-84dd-ab361d4f201a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Visualize the sample, with superpoint centroids and the edges connecting them\n",
    "show(\n",
    "    nag_visu,\n",
    "    class_names=CLASS_NAMES,\n",
    "    ignore=NUM_CLASSES,\n",
    "    class_colors=CLASS_COLORS, \n",
    "    max_points=100000, \n",
    "    centroids=True, \n",
    "    h_edge=True, \n",
    "    h_edge_width=2\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08758db4-d865-4c84-b9f8-3adca6d3fdef",
   "metadata": {},
   "source": [
    "## Side-by-side visualization mode\n",
    "By setting `gap` to a chosen 3D offset, we can visualize all partition levels at once. In addition, setting `v_edge=True` displays the vertical edges connecting superpoints with their children."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cd261bc-486e-4c58-acaa-9a37ed24444f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Visualize all partition levels side by side, with vertical edges\n",
    "show(\n",
    "    nag_visu,\n",
    "    figsize=1000,\n",
    "    class_names=CLASS_NAMES,\n",
    "    ignore=NUM_CLASSES,\n",
    "    class_colors=CLASS_COLORS, \n",
    "    max_points=100000, \n",
    "    centroids=True, \n",
    "    v_edge=True, \n",
    "    v_edge_width=1, \n",
    "    gap=[0, 0, 10]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "def3d228-29fd-466a-b27d-e45b4de77429",
   "metadata": {},
   "source": [
    "## Exporting your visualization to HTML\n",
    "You can export your interactive visualization to HTML. \n",
    "You can then share the exported file, which can be opened in any web browser with an internet connection.\n",
    "\n",
    "To export a visualization, simply specify a `path` to which the file should be saved.\n",
    "Additionally, you may set a `title` to be displayed in your HTML."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be0113c8-b4a2-4033-b49f-49ff6cbbe3da",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Visualize the sample and export it to an HTML file\n",
    "show(\n",
    "    nag_visu,\n",
    "    figsize=1600,\n",
    "    class_names=CLASS_NAMES,\n",
    "    ignore=NUM_CLASSES,\n",
    "    class_colors=CLASS_COLORS, \n",
    "    max_points=100000,\n",
    "    title=\"My Interactive Visualization Partition\", \n",
    "    path=\"my_interactive_visualization.html\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69df8eb7-d7f5-4e98-86ee-8dc231939b2a",
   "metadata": {},
   "source": [
    "## Going further with visualization\n",
    "See the commented code in `src.visualization` for more visualization options."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:spt] *",
   "language": "python",
   "name": "conda-env-spt-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
